#!/usr/bin/env python3
# I24 — No-Signalling with balanced blocks (deterministic present-act control)
#
# CONTROL (theory-faithful):
#  • Two wings A,B. At each tick t we choose local settings x_t,y_t ∈ {0,1} by a *deterministic 4-tick block*:
#        block = [(x,y)]=[(0,0),(1,0),(0,1),(1,1)] repeated ⇒ every block has one of each pair.
#  • Local outcomes are generated by *independent DDA-style counters* per wing & setting:
#        acc_i[s] ← acc_i[s] + rate_i[s]   (only when that setting s is active on wing i)
#        click_i(t)=1 and acc_i[s]←acc_i[s]−th_i[s] if acc_i[s] ≥ th_i[s], else 0.
#    (Boolean/ordinal; no RNG; remote setting never enters local eligibility.)
#
# DIAGNOSTICS (write-only):
#  • Verify exact block balance of (x,y) for every block.
#  • For each wing i and each local setting s: compute marginals
#        p_i(s|remote=0) and p_i(s|remote=1) and check |difference| ≤ marg_tol
#        (the remote setting is y for wing A and x for wing B).
#  • Totals per setting per wing reported; optional stability checks.
#
# OUTPUT:
#  • outputs/metrics/i24_marginals.csv — counts & rates per wing, per setting, per remote setting.
#  • outputs/metrics/i24_blocks.csv — per-block counts of (00,10,01,11) to certify balance.
#  • outputs/audits/i24_audit.json — full summary + acceptance flags.
#  • outputs/run_info/result_line.txt — one-line summary: PASS flag, worst-case marginal delta, block-balance flag.
#
import argparse, json, os, sys, csv, math
from typing import List, Dict, Tuple

def utc_timestamp():
    """Return the current UTC time formatted as ``YYYY-MM-DDTHH-MM-SSZ``.

    Hyphens are used in the time part as well (no colons).
    """
    from datetime import datetime, timezone

    now_utc = datetime.now(tz=timezone.utc)
    return now_utc.strftime("%Y-%m-%dT%H-%M-%SZ")

def ensure_dirs(root, *subs):
    """Create every sub-directory named in *subs* beneath *root*.

    Intermediate directories are created as needed and directories that
    already exist are left untouched.
    """
    for relative in subs:
        target = os.path.join(root, relative)
        os.makedirs(target, exist_ok=True)

def write_text(path, text):
    """Write *text* to *path* as UTF-8, replacing any existing file content."""
    with open(path, mode="w", encoding="utf-8") as sink:
        sink.write(text)

def dump_json(path, obj):
    """Serialise *obj* as JSON into *path* (UTF-8, 2-space indent, sorted keys)."""
    with open(path, mode="w", encoding="utf-8") as sink:
        json.dump(obj, sink, sort_keys=True, indent=2)

# ---------- Deterministic present-act local generators ----------
class DDAGen:
    """Deterministic integer accumulator that emits a click on threshold crossing.

    Attributes:
        th:   threshold, coerced to int.
        rate: increment added on every active step, coerced to int.
        acc:  current accumulator value, initialised to ``phase % th``.
    """

    def __init__(self, th: int, rate: int, phase: int = 0):
        self.th = int(th)
        self.rate = int(rate)
        # Fold the starting phase into the range defined by the threshold.
        self.acc = int(phase) % self.th

    def step_active(self) -> int:
        """Add ``rate`` to the accumulator and report whether it fired.

        Returns 1 and subtracts ``th`` once when the new value reaches the
        threshold; otherwise returns 0 and keeps the new value. Only a single
        subtraction is performed per step.
        """
        advanced = self.acc + self.rate
        fired = advanced >= self.th
        self.acc = advanced - self.th if fired else advanced
        return 1 if fired else 0

def run_i24(M: Dict, outdir: str) -> Dict:
    """Run the I24 no-signalling control and write its diagnostics.

    The simulation walks ``H // 4`` blocks of the fixed setting schedule
    ``(0,0),(1,0),(0,1),(1,1)``. On every tick wing A steps the generator for
    its own setting ``x`` and wing B steps the generator for its own setting
    ``y``; the remote setting is only used to bucket the tallies.

    Args:
        M: manifest dict. Keys read here: ``"H"``, ``"control"``,
            ``"accumulators"`` (per wing ``"A"``/``"B"``: ``"th"``, ``"rate"``,
            ``"phase"``, each indexable by setting 0/1) and ``"acceptance"``
            (``"marg_tol"``, ``"block_tol"``).
        outdir: run directory; files are written under ``outdir/outputs/...``.

    Returns:
        The audit dict that is also written to ``i24_audit.json``, including
        the boolean ``"passed"`` flag.

    Raises:
        AssertionError: if ``H`` is not a multiple of 4.
        KeyError: if a required manifest key is missing.
    """
    H = int(M["H"])  # total ticks; must be a multiple of 4
    assert H % 4 == 0, "H must be a multiple of 4 (balanced 4-tick blocks)."
    blocks = H // 4
    control = M["control"]
    acc_pars = M["accumulators"]
    # Read tolerances up front so a malformed manifest fails before simulating.
    acceptance = M["acceptance"]
    marg_tol = float(acceptance["marg_tol"])
    block_tol = int(acceptance["block_tol"])

    wings = ("A", "B")
    settings = (0, 1)

    # One independent generator per wing and per local setting.
    gens = {
        wing: [
            DDAGen(acc_pars[wing]["th"][s], acc_pars[wing]["rate"][s], acc_pars[wing]["phase"][s])
            for s in settings
        ]
        for wing in wings
    }

    # Balanced 4-tick schedule: every (x, y) pair appears once per block.
    schedule = ((0, 0), (1, 0), (0, 1), (1, 1))

    # counts[(wing, local setting, remote setting)] -> [n_clicks, n_trials]
    counts = {(wing, s, r): [0, 0] for wing in wings for s in settings for r in settings}
    block_counts = []  # one {"00","10","01","11" -> count} dict per block
    total_clicks = {"A": [0, 0], "B": [0, 0]}  # clicks per local setting

    for _ in range(blocks):
        hist = dict.fromkeys(schedule, 0)
        for x, y in schedule:
            # Each wing consults only its own setting.
            click_a = gens["A"][x].step_active()
            click_b = gens["B"][y].step_active()
            # For A the remote setting is y; for B it is x.
            for key, click in ((("A", x, y), click_a), (("B", y, x), click_b)):
                counts[key][0] += click
                counts[key][1] += 1
            total_clicks["A"][x] += click_a
            total_clicks["B"][y] += click_b
            hist[(x, y)] += 1
        block_counts.append({f"{x}{y}": hist[(x, y)] for (x, y) in schedule})

    # Block balance: every pair must occur once per block, within block_tol.
    block_ok = all(
        abs(n - 1) <= block_tol
        for h in block_counts
        for n in h.values()
    )

    def rate(n_clicks, n_trials):
        return (n_clicks / n_trials) if n_trials > 0 else 0.0

    # Marginals per wing and local setting, conditioned on the remote setting.
    deltas = []  # |p(s|remote=0) - p(s|remote=1)| for each wing/setting
    marg = {"A": {}, "B": {}}
    for wing in wings:
        for s in settings:
            p0 = rate(*counts[(wing, s, 0)])
            p1 = rate(*counts[(wing, s, 1)])
            marg[wing][f"s{s}|r0"] = p0
            marg[wing][f"s{s}|r1"] = p1
            deltas.append(abs(p0 - p1))

    worst_delta = max(deltas)

    pass_blocks = block_ok
    pass_marg = worst_delta <= marg_tol
    passed = bool(pass_blocks and pass_marg)

    audit = {
        "sim": "I24_no_signalling",
        "H": H,
        "blocks": blocks,
        "schedule": "(0,0),(1,0),(0,1),(1,1) repeated",
        "accumulators": acc_pars,
        "control": control,
        "marginals": marg,
        "total_clicks": total_clicks,
        "block_counts": block_counts[:16],  # preview of the first blocks only
        "worst_delta": worst_delta,
        "accept": acceptance,
        "passed": passed,
    }

    # Make the function usable on its own: create the output folders if absent.
    audits_dir = os.path.join(outdir, "outputs/audits")
    metrics_dir = os.path.join(outdir, "outputs/metrics")
    run_info_dir = os.path.join(outdir, "outputs/run_info")
    for folder in (audits_dir, metrics_dir, run_info_dir):
        os.makedirs(folder, exist_ok=True)

    dump_json(os.path.join(audits_dir, "i24_audit.json"), audit)

    with open(os.path.join(metrics_dir, "i24_marginals.csv"), "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["wing", "local_setting", "remote_setting", "clicks", "trials", "rate"])
        for wing in wings:
            for s in settings:
                for r in settings:
                    n_clicks, n_trials = counts[(wing, s, r)]
                    w.writerow([wing, s, r, n_clicks, n_trials, rate(n_clicks, n_trials)])

    with open(os.path.join(metrics_dir, "i24_blocks.csv"), "w", newline="", encoding="utf-8") as f:
        w = csv.writer(f)
        w.writerow(["block", "n00", "n10", "n01", "n11"])
        for b, h in enumerate(block_counts):
            w.writerow([b, h["00"], h["10"], h["01"], h["11"]])

    line = f"I24 PASS={passed} worst_delta={worst_delta:.5f} blocks_ok={pass_blocks}"
    write_text(os.path.join(run_info_dir, "result_line.txt"), line)
    print(line)
    return audit

def main():
    """Command-line entry point: load the manifest, prepare the run folder, run I24.

    Expects ``--manifest`` (path to a JSON manifest) and ``--outdir`` (run
    directory). A copy of the manifest and a small environment record are
    written under the run directory before the simulation starts.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--manifest", required=True)
    ap.add_argument("--outdir", required=True)
    args = ap.parse_args()

    ensure_dirs(args.outdir, "config", "outputs/metrics", "outputs/audits", "outputs/run_info", "logs")
    # Use a context manager so the manifest handle is closed deterministically.
    with open(args.manifest, "r", encoding="utf-8") as f:
        M = json.load(f)
    # Persist manifest & env
    dump_json(os.path.join(args.outdir, "config", "manifest_i24.json"), M)
    # Single braces so the values, not the literal expressions, are recorded.
    write_text(os.path.join(args.outdir, "logs", "env.txt"),
               f"utc={utc_timestamp()}\nos={os.name}\npython={sys.version.split()[0]}\n")

    run_i24(M, args.outdir)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
